{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O_putZlTpVgI"
      },
      "source": [
        "<span style=\"color: red;\">Requirement when running in Google Colab</span>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "053w39xVpVgJ"
      },
      "outputs": [],
      "source": [
        "!pip install diffusers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z4BdpYtFpVgK"
      },
      "source": [
        "#  Chapter 4 - Attention Layer Breakdown\n",
        "\n",
        "Attention layers, a key component of the transformer architecture, were introduced to Stable Diffusion to enhance its ability to capture long-range dependencies and context in image generation. These layers, inspired by the seminal \"Attention Is All You Need\" (https://arxiv.org/abs/1706.03762) paper by Vaswani et al. (2017), allow the model to focus on relevant parts of the input when generating each part of the output. In Stable Diffusion, attention layers play a crucial role in connecting text prompts to visual elements. Understanding and breaking down these attention layers is vital for advanced users and researchers who aim to modify or optimise the model's behavior. By overriding and analysing the attention mechanism, one can gain insights into how the model interprets prompts and constructs images, paving the way for targeted improvements, custom behaviors, or even novel applications of the technology."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k22FQE91pVgL"
      },
      "source": [
        "The original architecture diagram shows only the cross-attention, in which the keys and values come from the prompt condition and the queries come from the image latents. In practice, every cross-attention layer in Stable Diffusion is preceded by a self-attention layer, in which the queries, keys and values all come from the image latents."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PPrvMBwhpVgL"
      },
      "source": [
        "![image.png](assets/image5.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LzHzrZvTpVgL"
      },
      "source": [
        "This part is technically the same as in the previous chapter, but for simplicity we have removed the step-by-step breakdown and shortened it. All individual elements of the architecture can also be accessed through the pipe itself; for example, the U-Net is available as `pipe.unet`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tc4seE4YpVgL"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "from diffusers import StableDiffusionPipeline\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from typing import Optional\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n",
        "\n",
        "pipe = StableDiffusionPipeline.from_pretrained(model_id)\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "prompt = \"A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "# encode_prompt returns (prompt_embeds, negative_prompt_embeds)\n",
        "prompt_embeds = pipe.encode_prompt(prompt=prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "# Unconditional embeddings first, then the prompt embeddings, matching the chunk(2) order used for guidance later\n",
        "prompt_embeds_combined = torch.cat([prompt_embeds[1], prompt_embeds[0]])\n",
        "\n",
        "\n",
        "# Initial noise in latent space, generated with a fixed seed for reproducibility\n",
        "latents = torch.randn((1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(220)).to(\"cuda\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BIA3tEjnpVgM"
      },
      "source": [
        "Unfortunately, there is no cleaner way of doing this than recursively locating the attention layers by their class name and overwriting the forward function of each one with our custom version.\n",
        "\n",
        "First, we write out an expanded version of the attention calculation that can replace the built-in calculation in the model's layers, which gives us access to the intermediate values. If you are not familiar with how self-attention or cross-attention is calculated, I highly recommend this 3Blue1Brown YouTube video, https://youtu.be/eMlx5fFNoYc, which explains the calculation in detail. The default calculation in Stable Diffusion is an optimised C implementation with Python bindings, so the pure Python version adds a slight overhead to sampling."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LgroqsAWpVgM"
      },
      "outputs": [],
      "source": [
        "def contextual_forward(self):\n",
        "\n",
        "    def forward_modified(\n",
        "        hidden_states: torch.FloatTensor,\n",
        "        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n",
        "        attention_mask: Optional[torch.FloatTensor] = None,\n",
        "        temb: Optional[torch.FloatTensor] = None,\n",
        "    ) -> torch.FloatTensor:\n",
        "\n",
        "            residual = hidden_states\n",
        "\n",
        "            input_ndim = hidden_states.ndim\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                batch_size, channel, height, width = hidden_states.shape\n",
        "                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
        "\n",
        "            batch_size, _, _ = (\n",
        "                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
        "            )\n",
        "\n",
        "            if self.group_norm is not None:\n",
        "                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
        "\n",
        "            query = self.to_q(hidden_states)\n",
        "\n",
        "            if encoder_hidden_states is None:\n",
        "                encoder_hidden_states = hidden_states\n",
        "            elif self.norm_cross:\n",
        "                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)\n",
        "\n",
        "            key = self.to_k(encoder_hidden_states)\n",
        "            value = self.to_v(encoder_hidden_states)\n",
        "\n",
        "            query = self.head_to_batch_dim(query)\n",
        "            key = self.head_to_batch_dim(key)\n",
        "            value = self.head_to_batch_dim(value)\n",
        "\n",
        "            attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))\n",
        "\n",
        "            attention_probs = attention_scores.softmax(dim=-1)\n",
        "            del attention_scores\n",
        "\n",
        "            hidden_states = torch.bmm(attention_probs, value)\n",
        "            hidden_states = self.batch_to_head_dim(hidden_states)\n",
        "            del attention_probs\n",
        "\n",
        "            hidden_states = self.to_out[0](hidden_states)\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
        "\n",
        "            if self.residual_connection:\n",
        "                hidden_states = hidden_states + residual\n",
        "\n",
        "            hidden_states = hidden_states / self.rescale_output_factor\n",
        "\n",
        "            return hidden_states\n",
        "\n",
        "    return forward_modified"
      ]
    },
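    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three central steps of `forward_modified` (scaled scores, softmax, weighted sum of the values) can be tried on small random tensors. The cell below is a minimal sketch rather than part of the pipeline: the tensor shapes and the `toy_` names are made up for illustration, the scale is assumed to be `head_dim ** -0.5`, and `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later) is used only as a reference to compare against."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "# Local generator so the global random state of the notebook is left untouched\n",
        "toy_gen = torch.Generator().manual_seed(0)\n",
        "\n",
        "# Hypothetical toy shapes: (batch * heads, tokens, head_dim)\n",
        "toy_query = torch.randn(2, 4, 8, generator=toy_gen)  # 4 query tokens (image latents)\n",
        "toy_key = torch.randn(2, 3, 8, generator=toy_gen)    # 3 key tokens (e.g. prompt tokens)\n",
        "toy_value = torch.randn(2, 3, 8, generator=toy_gen)\n",
        "\n",
        "# Assumed scale: 1 / sqrt(head_dim)\n",
        "toy_scale = toy_query.shape[-1] ** -0.5\n",
        "\n",
        "# Same three steps as in forward_modified\n",
        "toy_scores = toy_scale * torch.bmm(toy_query, toy_key.transpose(-1, -2))\n",
        "toy_probs = toy_scores.softmax(dim=-1)  # each row sums to 1 over the key tokens\n",
        "toy_out = torch.bmm(toy_probs, toy_value)\n",
        "\n",
        "# Reference result from PyTorch's fused implementation\n",
        "toy_ref = F.scaled_dot_product_attention(toy_query, toy_key, toy_value)\n",
        "\n",
        "print(toy_probs.shape)  # torch.Size([2, 4, 3]): one weight per (query token, key token) pair\n",
        "print(torch.allclose(toy_out, toy_ref, atol=1e-5))"
      ]
    },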
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z_w58AnJpVgN"
      },
      "source": [
        "This is, unfortunately, the messy part: we have to iterate through all the children of the model to find the relevant Attention layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0so-CMR-pVgN"
      },
      "outputs": [],
      "source": [
        "\n",
        "def apply_forward_function(unet, child=None):\n",
        "    # Recursively walk the module tree and patch every Attention layer\n",
        "    if child is None:\n",
        "        # First call: start from the direct children of the U-Net\n",
        "        for _, sub_child in unet.named_children():\n",
        "            apply_forward_function(unet, sub_child)\n",
        "    elif child.__class__.__name__ == 'Attention':\n",
        "        # Replace the layer's forward with our expanded implementation\n",
        "        child.forward = contextual_forward(child)\n",
        "    else:\n",
        "        # Any other module: keep descending into its children\n",
        "        for sub_child in child.children():\n",
        "            apply_forward_function(unet, sub_child)"
      ]
    },
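    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "PyTorch's `named_modules()` already visits every nested submodule, so the same search can also be sketched as a flat loop. The cell below runs on a toy model only: `ToyAttention` and `toy_unet` are hypothetical names defined for this illustration and are not part of diffusers. With the real U-Net the class name to match would be `Attention`, as in the function above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "class ToyAttention(nn.Module):\n",
        "    # Hypothetical stand-in for an attention layer\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.to_q = nn.Linear(4, 4)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.to_q(x)\n",
        "\n",
        "\n",
        "# Toy model with one nested and one top-level attention layer\n",
        "toy_unet = nn.Sequential(\n",
        "    nn.Linear(4, 4),\n",
        "    nn.Sequential(ToyAttention(), nn.ReLU()),\n",
        "    ToyAttention(),\n",
        ")\n",
        "\n",
        "# named_modules() yields (qualified_name, module) for the model and all of its descendants\n",
        "attention_names = [\n",
        "    name for name, module in toy_unet.named_modules()\n",
        "    if module.__class__.__name__ == 'ToyAttention'\n",
        "]\n",
        "\n",
        "print(attention_names)  # ['1.0', '2']"
      ]
    },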
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N0EkHlGTpVgN"
      },
      "source": [
        "Now, before running the inference steps for sampling, we overwrite the current attention layer implementation in the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fHLSnqLipVgO"
      },
      "outputs": [],
      "source": [
        "num_inference_steps = 50\n",
        "guidance_scale = 7.5\n",
        "pipe.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = pipe.scheduler.timesteps\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    # Before running the model, replace the forward function of every attention layer in the U-Net with our custom version\n",
        "    apply_forward_function(pipe.unet)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "evUvoz5MpVgO"
      },
      "source": [
        "The result should, in theory, look exactly the same as in the previous chapter, since we have only overridden the forward function of the attention layers for our breakdown."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VKDAeoMkpVgO"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()\n",
        "    image_np = image_np - image_np.min()\n",
        "    image_np = image_np / image_np.max()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NKiqDlUvpVgP"
      },
      "outputs": [],
      "source": [
        "plt.imshow(image_np)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "L4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
